import os
import numpy as np

from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
import mindspore.nn as nn
import mindspore.common.initializer as weight_init

from src.model_utils.config import config
from src.resnet import conv_variance_scaling_initializer
from src.resnet import resnet50 as resnet
from src.dataset import create_dataset as create_dataset
from src.lr_generator import get_lr
from src.eval_callback import EvalCallBack
from src.metric import DistAccuracy, ClassifyCorrectCell

# Seed MindSpore's global random state with the fixed value 1 at import time.
# NOTE(review): presumably done so weight init / shuffling are repeatable between
# runs -- confirm which RNGs `set_seed` actually covers.
set_seed(1)


class LossCallBack(LossMonitor):
    """Loss monitor whose log line offsets the epoch number by already-trained epochs.

    Args:
        has_trained_epoch (int): Number of epochs completed before this run.
            It is added to the epoch number printed by ``step_end``. Default: 0.
    """

    def __init__(self, has_trained_epoch=0):
        super(LossCallBack, self).__init__()
        self.has_trained_epoch = has_trained_epoch

    def step_end(self, run_context):
        """Validate and optionally print the loss at the end of a training step.

        Args:
            run_context: Callback context; ``original_args()`` must expose
                ``net_outputs``, ``cur_step_num``, ``batch_num`` and ``cur_epoch_num``.

        Raises:
            ValueError: If the (averaged) loss is NaN or infinite.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        # When the network returns several outputs, the first one is taken as the loss.
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        # Reduce a tensor loss to a single NumPy scalar (one host copy only).
        if isinstance(loss, Tensor):
            loss_array = loss.asnumpy()
            if isinstance(loss_array, np.ndarray):
                loss = np.mean(loss_array)

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        # np.mean yields a NumPy scalar; np.float32 / np.float16 are NOT subclasses
        # of Python float, so np.floating must be accepted as well or the check
        # below would silently be skipped for the usual float32 loss.
        if isinstance(loss, (float, np.floating)) and not np.isfinite(loss):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        # `_per_print_times` is not defined in this class; presumably set by LossMonitor.__init__.
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + int(self.has_trained_epoch),
                                                      cur_step_in_epoch, loss), flush=True)


def apply_eval(eval_param):
    """Evaluate a model on a dataset and return one named metric.

    Args:
        eval_param (dict): Must provide "model" (an object exposing
            ``eval(dataset)``), "dataset" (handed to ``eval``) and
            "metrics_name" (key selected from the evaluation result).

    Returns:
        The entry of the evaluation result stored under ``metrics_name``.
    """
    # All three entries are looked up before evaluation starts, so a missing
    # key raises KeyError without running the (possibly expensive) evaluation.
    model, dataset, metric_key = (eval_param[key] for key in ("model", "dataset", "metrics_name"))
    return model.eval(dataset)[metric_key]


def set_parameter():
    """Configure the MindSpore execution context from the global ``config``.

    Side effects:
        * Forces ``config.run_distribute`` to False when the device target is "CPU".
        * Calls ``context.set_context`` in graph mode (honouring
          ``config.save_graphs``) when ``config.mode_name == 'GRAPH'``,
          otherwise in PyNative mode with graph saving disabled.
        * Enables the parameter-server context when ``config.parameter_server`` is set.
    """
    device = config.device_target
    if device == "CPU":
        config.run_distribute = False

    # Pick execution mode; graph dumps are only ever requested in graph mode.
    graph_mode = config.mode_name == 'GRAPH'
    context.set_context(
        mode=context.GRAPH_MODE if graph_mode else context.PYNATIVE_MODE,
        device_target=device,
        save_graphs=config.save_graphs if graph_mode else False,
    )

    if config.parameter_server:
        context.set_ps_context(enable_ps=True)


def init_weight(net):
    """Re-initialise Conv2d and Dense weights of ``net`` as selected in ``config``.

    Conv2d cells follow ``config.conv_init`` ("XavierUniform" or
    "TruncatedNormal"); Dense cells follow ``config.dense_init``
    ("TruncatedNormal" or "RandomNormal"). Any other setting leaves the
    cell's weight untouched.

    Args:
        net: Network exposing ``cells_and_names()``.
    """
    for _, layer in net.cells_and_names():
        if isinstance(layer, nn.Conv2d):
            conv_scheme = config.conv_init
            if conv_scheme == "XavierUniform":
                xavier = weight_init.initializer(weight_init.XavierUniform(),
                                                 layer.weight.shape,
                                                 layer.weight.dtype)
                layer.weight.set_data(xavier)
            elif conv_scheme == "TruncatedNormal":
                # Only the first kernel dimension is passed to the project helper.
                scaled = conv_variance_scaling_initializer(layer.in_channels,
                                                           layer.out_channels,
                                                           layer.kernel_size[0])
                layer.weight.set_data(scaled)

        if isinstance(layer, nn.Dense):
            dense_scheme = config.dense_init
            if dense_scheme == "TruncatedNormal":
                truncated = weight_init.initializer(weight_init.TruncatedNormal(),
                                                    layer.weight.shape,
                                                    layer.weight.dtype)
                layer.weight.set_data(truncated)
            elif dense_scheme == "RandomNormal":
                n_in = layer.in_channels
                n_out = layer.out_channels
                # Draw N(0, 0.01) samples flat, then shape them as (out, in).
                samples = np.random.normal(loc=0, scale=0.01, size=n_out * n_in)
                dense_weight = Tensor(samples.reshape((n_out, n_in)), dtype=layer.weight.dtype)
                layer.weight.set_data(dense_weight)


def init_lr(step_size):
    """Build the learning-rate schedule from the global ``config``.

    Args:
        step_size: Number of steps per epoch, forwarded as ``steps_per_epoch``.

    Returns:
        Whatever ``get_lr`` returns for the configured schedule.
    """
    schedule_args = dict(
        lr_init=config.lr_init,
        lr_end=config.lr_end,
        lr_max=config.lr_max,
        warmup_epochs=config.warmup_epochs,
        total_epochs=config.epoch_size,
        steps_per_epoch=step_size,
        lr_decay_mode=config.lr_decay_mode,
    )
    return get_lr(**schedule_args)


def init_loss_scale():
    """Create the training loss function.

    Despite its name, this returns the loss cell (sparse softmax
    cross-entropy with mean reduction), not a loss-scale object.
    """
    return SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def run_eval(target, model, ckpt_save_dir, cb):
    """Append an evaluation callback to ``cb`` when ``config.run_eval`` is enabled.

    Args:
        target: Device target forwarded to ``create_dataset``.
        model: Model handed to the evaluation callback.
        ckpt_save_dir: Directory passed to the callback as ``ckpt_directory``.
        cb (list): Callback list; extended in place.

    Raises:
        ValueError: If evaluation is enabled but ``config.eval_dataset_path``
            is None or not an existing directory.
    """
    if not config.run_eval:
        return

    eval_dir = config.eval_dataset_path
    dir_is_usable = eval_dir is not None and os.path.isdir(eval_dir)
    if not dir_is_usable:
        raise ValueError("{} is not a existing path.".format(eval_dir))

    eval_dataset = create_dataset(dataset_path=config.eval_dataset_path, do_train=False,
                                  batch_size=config.batch_size, train_image_size=config.train_image_size,
                                  eval_image_size=config.eval_image_size,
                                  target=target, enable_cache=config.enable_cache,
                                  cache_session_id=config.cache_session_id)
    callback_args = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"}
    # NOTE: `besk_ckpt_name` is the keyword spelling used for EvalCallBack here; keep it as is.
    evaluator = EvalCallBack(apply_eval, callback_args, interval=config.eval_interval,
                             eval_start_epoch=config.eval_start_epoch, save_best_ckpt=config.save_best_ckpt,
                             ckpt_directory=ckpt_save_dir, besk_ckpt_name="best_acc.ckpt",
                             metrics_name="acc")
    cb.append(evaluator)


def init_group_params(net):
    """Split trainable parameters into weight-decayed and non-decayed groups.

    Parameters whose name contains 'beta', 'gamma' or 'bias' are excluded
    from weight decay; all others receive ``config.weight_decay``.

    Args:
        net: Network exposing ``trainable_params()``.

    Returns:
        list[dict]: Three entries -- the decayed group, the non-decayed group,
        and an ``order_params`` entry holding all trainable parameters.
    """
    exempt_markers = ('beta', 'gamma', 'bias')
    with_decay, without_decay = [], []
    for candidate in net.trainable_params():
        is_exempt = any(marker in candidate.name for marker in exempt_markers)
        bucket = without_decay if is_exempt else with_decay
        bucket.append(candidate)

    return [
        {'params': with_decay, 'weight_decay': config.weight_decay},
        {'params': without_decay},
        {'order_params': net.trainable_params()},
    ]


def set_save_ckpt_dir():
    """Return the checkpoint directory: ``config.output_path`` joined with ``config.checkpoint_path``."""
    return os.path.join(config.output_path, config.checkpoint_path)


def train_net():
    """Assemble dataset, network, optimizer and callbacks, then run training.

    Everything is driven by the global ``config``. The number of epochs
    trained is ``config.epoch_size`` minus ``config.has_trained_epoch``
    (which is also written to ``config.pretrain_epoch_size``).
    """
    device = config.device_target
    set_parameter()

    # --- data -------------------------------------------------------------
    train_ds = create_dataset(dataset_path=config.data_path, do_train=True, repeat_num=1,
                              batch_size=config.batch_size,
                              distribute=config.run_distribute)
    steps_per_epoch = train_ds.get_dataset_size()

    # --- network ----------------------------------------------------------
    network = resnet(class_num=config.class_num)
    if config.parameter_server:
        network.set_param_ps()
    init_weight(net=network)

    # --- optimizer / loss -------------------------------------------------
    learning_rate = Tensor(init_lr(step_size=steps_per_epoch))
    optimizer = Momentum(init_group_params(network), learning_rate, config.momentum,
                         loss_scale=config.loss_scale)
    loss_fn = init_loss_scale()
    scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    # Distributed runs use a dedicated eval network and accuracy metric.
    if config.run_distribute:
        eval_net = ClassifyCorrectCell(network)
        metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.device_num)}
    else:
        eval_net = None
        metrics = {"acc"}

    # --- model ------------------------------------------------------------
    # Mixed precision ("O2") is only used for the listed nets, and never with
    # a parameter server or on CPU.
    mixed_precision_nets = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "se-resnet50")
    use_mixed_precision = (config.net_name in mixed_precision_nets
                           and not config.parameter_server
                           and device != "CPU")
    if use_mixed_precision:
        model = Model(network, loss_fn=loss_fn, optimizer=optimizer, loss_scale_manager=scale_manager,
                      metrics=metrics, amp_level="O2", boost_level=config.boost_mode,
                      keep_batchnorm_fp32=False, eval_network=eval_net)
    else:
        model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics=metrics,
                      eval_network=eval_net)

    # --- callbacks --------------------------------------------------------
    callbacks = [TimeMonitor(data_size=steps_per_epoch), LossCallBack(config.has_trained_epoch)]
    ckpt_dir = set_save_ckpt_dir()
    if config.save_checkpoint:
        resume_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
        ckpt_config = CheckpointConfig(save_checkpoint_steps=500,
                                       keep_checkpoint_max=config.keep_checkpoint_max,
                                       append_info=resume_info)
        callbacks.append(ModelCheckpoint(prefix="resnet", directory=ckpt_dir, config=ckpt_config))
    run_eval(device, model, ckpt_dir, callbacks)

    # --- training ---------------------------------------------------------
    # Dataset sinking is disabled for parameter-server runs and on CPU.
    sink_mode = not (config.parameter_server or device == "CPU")
    config.pretrain_epoch_size = config.has_trained_epoch
    remaining_epochs = config.epoch_size - config.pretrain_epoch_size
    model.train(remaining_epochs, train_ds, callbacks=callbacks,
                sink_size=train_ds.get_dataset_size(), dataset_sink_mode=sink_mode)

    if config.run_eval and config.enable_cache:
        print('Remember to shut down the cache server via "cache_admin --stop"')


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    train_net()
